use std::ops::{Index, IndexMut};

use cgmath::{vec2, Vector2};

use crate::math::Rect2D;

/// A single colour value with 8 bits per channel (red, green, blue, alpha).
///
/// `#[repr(C)]` fixes the field order, so a `Pixel` is laid out as four
/// consecutive bytes in `r`, `g`, `b`, `a` order.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct Pixel {
    /// Red channel.
    r: u8,
    /// Green channel.
    g: u8,
    /// Blue channel.
    b: u8,
    /// Alpha channel; the predefined opaque colours below all use 255.
    a: u8,
}

impl Pixel {
    pub const TRANSPARENT: Self = Self::new(0, 0, 0, 0);
    pub const BLACK: Self = Self::new(0, 0, 0, 255);
    pub const BLUE: Self = Self::new(0, 0, 255, 255);
    pub const GREEN: Self = Self::new(0, 255, 0, 255);
    pub const CYAN: Self = Self::new(0, 255, 255, 255);
    pub const RED: Self = Self::new(255, 0, 0, 255);
    pub const PURPLE: Self = Self::new(255, 0, 255, 255);
    pub const YELLOW: Self = Self::new(255, 255, 0, 255);
    pub const WHITE: Self = Self::new(255, 255, 255, 255);

    /// Builds a pixel from its four channel values.
    pub const fn new(r: u8, g: u8, b: u8, a: u8) -> Self {
        Self { r, g, b, a }
    }

    /// Builds a pixel from an `[r, g, b, a]` array.
    pub const fn from_rgba_array(rgba: [u8; 4]) -> Self {
        let [r, g, b, a] = rgba;
        Self::new(r, g, b, a)
    }
}

//impl AsBytes for Pixel {}

/// Unpacks a colour packed as `0xRRGGBBAA`: red in the most significant byte,
/// alpha in the least significant one.
///
/// The conversion is defined on the numeric value of the `u32`, so it gives the
/// same result on little- and big-endian hosts. A plain `transmute` would not:
/// its result depends on the in-memory byte order of the integer.
impl From<u32> for Pixel {
    fn from(value: u32) -> Self {
        // `to_be_bytes` returns the most significant byte first, which is
        // exactly the `r, g, b, a` order of the packed value.
        let [r, g, b, a] = value.to_be_bytes();

        Self { r, g, b, a }
    }
}

/// A rectangular grid of [`Pixel`]s kept in one flat buffer.
///
/// Pixels are stored row by row: the pixel at `(x, y)` lives at index
/// `x + width * y` of `pixels`.
pub struct Pixmap {
    /// Flat pixel storage, expected to hold `width * height` entries.
    /// NOTE(review): the field is public and `from_pixels` does not check the
    /// length, so this invariant is not enforced by the type.
    pub pixels: Vec<Pixel>,
    /// Number of pixels in one row.
    width: u32,
    /// Number of rows.
    height: u32,
}

impl Pixmap {
    /// Creates a `width` x `height` pixmap filled with [`Pixel::TRANSPARENT`].
    pub fn new(width: u32, height: u32) -> Self {
        Self {
            // Multiply in `usize`: the product of two `u32` dimensions can
            // exceed `u32::MAX` even when the buffer itself is allocatable.
            pixels: vec![Pixel::TRANSPARENT; width as usize * height as usize],
            width,
            height,
        }
    }

    /// Creates a pixmap from an iterator of pixels in row-by-row order.
    ///
    /// The number of pixels yielded is not checked against `width * height`;
    /// the caller is responsible for passing matching dimensions.
    pub fn from_pixels<I: IntoIterator<Item = Pixel>>(pixels: I, width: u32, height: u32) -> Self {
        Self {
            pixels: pixels.into_iter().collect(),
            width,
            height,
        }
    }

    /// Creates a pixmap from a slice of rows.
    ///
    /// The height is the number of rows and the width is the length of the
    /// first row. An empty slice produces an empty 0x0 pixmap. Rows are
    /// concatenated as they are; their lengths are not checked against the
    /// first row.
    pub fn from_pixel_array_2d(pixels: &[&[Pixel]]) -> Self {
        let height = pixels.len() as u32;
        // `first()` instead of `pixels[0]` so an empty input does not panic.
        let width = pixels.first().map_or(0, |row| row.len()) as u32;
        let pixels: Vec<Pixel> = pixels
            .iter()
            .flat_map(|row| row.iter().copied())
            .collect();

        Self {
            pixels,
            width,
            height,
        }
    }

    /// Converts a region given in pixels into coordinates normalised by the
    /// pixmap size (each edge is divided by the pixmap width or height).
    ///
    /// The edges are handed to `Rect2D::new` in the order left, right, bottom,
    /// top. For a pixmap with a zero width or height the divisions yield
    /// non-finite values.
    pub fn create_region(&self, x: u32, y: u32, width: u32, height: u32) -> Rect2D<f32> {
        let pixmap_width = self.width as f32;
        let pixmap_height = self.height as f32;
        let left = x as f32 / pixmap_width;
        let top = y as f32 / pixmap_height;
        // Widen before adding so `x + width` / `y + height` cannot overflow `u32`.
        let right = (u64::from(x) + u64::from(width)) as f32 / pixmap_width;
        let bottom = (u64::from(y) + u64::from(height)) as f32 / pixmap_height;

        Rect2D::new(left, right, bottom, top)
    }

    /// Widens the pixmap by `amount` columns, added on the right of every row
    /// and filled with [`Pixel::TRANSPARENT`]. Existing pixels keep their
    /// `(x, y)` position.
    ///
    /// # Panics
    ///
    /// Panics if the new width does not fit in a `u32`.
    pub fn extend_width(&mut self, amount: u32) {
        let new_width = self
            .width
            .checked_add(amount)
            .expect("pixmap width overflows u32");
        let old_stride = self.width as usize;
        let new_stride = new_width as usize;

        let mut new_storage = vec![Pixel::TRANSPARENT; new_stride * self.height as usize];
        // A zero-width pixmap has nothing to copy (and `chunks_exact(0)` would panic).
        if old_stride > 0 {
            let new_rows = new_storage.chunks_exact_mut(new_stride);
            let old_rows = self.pixels.chunks_exact(old_stride);
            for (new_row, old_row) in new_rows.zip(old_rows) {
                // The old row becomes the left part of the wider row.
                new_row[..old_stride].copy_from_slice(old_row);
            }
        }

        self.width = new_width;
        self.pixels = new_storage;
    }

    /// Grows the pixmap by `amount` rows, appended after the last row and
    /// filled with [`Pixel::TRANSPARENT`].
    ///
    /// # Panics
    ///
    /// Panics if the new height does not fit in a `u32`.
    pub fn extend_height(&mut self, amount: u32) {
        let new_height = self
            .height
            .checked_add(amount)
            .expect("pixmap height overflows u32");
        // Computed in `usize` so `amount * width` cannot overflow `u32`.
        let added = amount as usize * self.width as usize;
        self.pixels
            .resize(self.pixels.len() + added, Pixel::TRANSPARENT);
        self.height = new_height;
    }

    /// Width of the pixmap in pixels.
    pub fn width(&self) -> u32 {
        self.width
    }

    /// Height of the pixmap in pixels.
    pub fn height(&self) -> u32 {
        self.height
    }

    /// Width and height of the pixmap as a vector `(width, height)`.
    pub fn size(&self) -> Vector2<u32> {
        vec2(self.width, self.height)
    }

    /// The pixmap dimensions as a `wgpu::Extent3d` with a single layer.
    pub fn extent(&self) -> wgpu::Extent3d {
        wgpu::Extent3d {
            width: self.width,
            height: self.height,
            depth_or_array_layers: 1,
        }
    }

    /// Describes the pixel buffer as a `wgpu::ImageDataLayout`: it starts at
    /// offset 0, each row takes `size_of::<Pixel>() * width` bytes, and one
    /// image consists of `height` rows.
    pub fn data_layout(&self) -> wgpu::ImageDataLayout {
        wgpu::ImageDataLayout {
            offset: 0,
            bytes_per_row: Some(std::mem::size_of::<Pixel>() as u32 * self.width),
            rows_per_image: Some(self.height),
        }
    }
}

/// Read access to the pixel at `(x, y)`.
///
/// # Panics
///
/// Panics if `x` is not smaller than the pixmap width, or if the resulting
/// position lies outside the pixel buffer.
impl Index<(u32, u32)> for Pixmap {
    type Output = Pixel;

    fn index(&self, index: (u32, u32)) -> &Self::Output {
        let (x, y) = index;
        // Without this check an `x` past the end of a row would silently
        // address a pixel on a following row.
        assert!(
            x < self.width,
            "x coordinate {} is out of bounds for pixmap width {}",
            x,
            self.width
        );
        // Computed in `usize` so `width * y` cannot overflow `u32`.
        &self.pixels[x as usize + self.width as usize * y as usize]
    }
}

/// Mutable access to the pixel at `(x, y)`.
///
/// # Panics
///
/// Panics if `x` is not smaller than the pixmap width, or if the resulting
/// position lies outside the pixel buffer.
impl IndexMut<(u32, u32)> for Pixmap {
    fn index_mut(&mut self, index: (u32, u32)) -> &mut Self::Output {
        let (x, y) = index;
        // Without this check an `x` past the end of a row would silently
        // address a pixel on a following row.
        assert!(
            x < self.width,
            "x coordinate {} is out of bounds for pixmap width {}",
            x,
            self.width
        );
        // Computed in `usize` so `width * y` cannot overflow `u32`.
        &mut self.pixels[x as usize + self.width as usize * y as usize]
    }
}
